# All functions needed to run wedge extension 

## Start out by loading parameters 
# Baseline wedge-extension parameters; the struct defaults below read row 1 of this table.
wedge_param_baseline_df = let param_path = joinpath(ddat, "parameters_wedges.csv")
    DataFrame(CSV.File(param_path; header = 1))
end

## Then define parameter structure 
# Parameters of the wedge extension.
# Most defaults are taken from the first row of the global `wedge_param_baseline_df`.
# Exceptions:
#   - `sigma` is derived from the global `eta_tilde`;
#   - `proba_c` comes from the global `baseline_parameters_df`;
#   - `connect_fixed_cost` is fixed at 0.0.
# Those globals must exist when the keyword constructor is called.
# Field order is significant: this file also defines positional `getindex` for the type.
@with_kw struct define_param_wedges
    # `connect_*` parameters (columns theta, cost_level, cost_elasticity, rho)
    connect_theta::Float64 = first(wedge_param_baseline_df.theta)
    connect_cost_level::Float64 = first(wedge_param_baseline_df.cost_level)
    connect_cost_elasticity::Float64 = first(wedge_param_baseline_df.cost_elasticity)
    connect_fixed_cost::Float64 = 0.0
    connect_rho::Float64 = first(wedge_param_baseline_df.rho)

    # epsilon | z_star is drawn as Normal(alpha_eps + beta_eps*log(z_star), sqrt(variance_eps_z))
    alpha_eps::Float64 = first(wedge_param_baseline_df.alpha_eps)
    beta_eps::Float64 = first(wedge_param_baseline_df.beta_eps)
    variance_eps_z::Float64 = first(wedge_param_baseline_df.variance_eps_z)

    # sigma = 1/(1 - eta_tilde); proba_c = probability weight on the C branch in the solvers
    sigma::Float64 = 1/(1-eta_tilde)
    proba_c::Float64 = first(baseline_parameters_df.proba_c)

    # moments of log(z_star) and of the log wedges (C / NC), plus wedge-z correlations
    mean_log_z_star::Float64 = first(wedge_param_baseline_df.mean_log_z_star)
    sd_log_z_star::Float64 = first(wedge_param_baseline_df.sd_log_z_star)
    mean_tau_C::Float64 = first(wedge_param_baseline_df.mean_tau_C)
    mean_tau_NC::Float64 = first(wedge_param_baseline_df.mean_tau_NC)
    sd_tau_C::Float64 = first(wedge_param_baseline_df.sd_tau_C)
    sd_tau_NC::Float64 = first(wedge_param_baseline_df.sd_tau_NC)
    corr_wedge_z_NC::Float64 = first(wedge_param_baseline_df.corr_wedge_z_NC)
    corr_wedge_z_C::Float64 = first(wedge_param_baseline_df.corr_wedge_z_C)
end

# Positional access to parameters: `param[i]` returns the i-th field of
# `define_param_wedges` in declaration order (e.g. `param[1] == param.connect_theta`).
function Base.getindex(params::define_param_wedges, idx)
    field_name = fieldnames(typeof(params))[idx]
    return getproperty(params, field_name)
end

# read z_star grid for wedges: parse the file once, then take
# column 1 as `z_star_grid_wedges` and column 2 as `z_star_distrib_wedges`
z_star_grid_wedges, z_star_distrib_wedges = let z_star_wedges_df = DataFrame(CSV.File(joinpath(ddat, "z_star_grid_wedges.csv"); header = 1))
    (z_star_wedges_df[:, 1], z_star_wedges_df[:, 2])
end

# read relative wedge shares (first column of each file); the solvers resample
# from these to obtain relative capital wedges for C and NC firms
capital_wedge_share_NC = DataFrame(CSV.File(joinpath(ddat, "capital_wedge_share_NC.csv"); header = 1))[:, 1]
capital_wedge_share_C  = DataFrame(CSV.File(joinpath(ddat, "capital_wedge_share_C.csv"); header = 1))[:, 1]


# compute GE aggregates 
# Aggregate per-grid-point expected firm choices into economy-wide totals for
# the economy with both connected (C) and non-connected (NC) firms.
# Every total has the form `sum(x .* distrib) * mass_firms`.
#
# Keywords:
#   w                      wage (multiplies labor demand in household income)
#   distrib                per-grid-point weights
#   mass_firms             total mass of firms
#   within_period_choices  object with per-grid-point vectors expected_capital,
#                          expected_labor, expected_revenue, expected_revenue_output,
#                          expected_profits, expected_m_R_C, expected_subsidies_C
#   tax                    rate(s) applied in the VAT revenue term; `nothing` falls
#                          back to a copy of the global `vat_tax`
#   param                  only `param.proba_c` is read
#
# Also reads the globals `gamma_gross`, `delta` and `profit_tax`.
# Returns a NamedTuple of aggregates (see the `return` at the end).
function compute_aggregates_wedges(; w, distrib, mass_firms, within_period_choices, tax = nothing, param) 

    # default to the status-quo tax
    if isnothing(tax)
        tax = copy(vat_tax)
    end 

    # Get total (productive) inputs; intermediates are a gamma_gross share of revenue
    productive_capital = sum( within_period_choices.expected_capital .* distrib ) * mass_firms
    productive_labor = sum( within_period_choices.expected_labor .* distrib ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms

    # labor market: demand equals productive labor 
    total_labor_demand = productive_labor

    # expected_m_R_C / expected_subsidies_C are conditional on being C, hence the proba_c scaling
    total_rent_seeking = param.proba_c * sum( within_period_choices.expected_m_R_C .* distrib ) * mass_firms

    # total value added output 
    total_value_added_output = total_output - productive_intermediates - total_rent_seeking

    # HH consumption under closed economy 
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates - total_rent_seeking # (enforce aggregate resource constraint -- only true if closed economy or BOP = 0) 

    # weighted profit sum, reused for CIT revenue and total profits
    profit_sum = sum( within_period_choices.expected_profits .* distrib )

    # total tax revenues and subsidies; profits are presumably after-tax, so
    # profit_tax/(1-profit_tax) grosses them up to the tax paid
    cit_revenue = profit_sum * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* ((1-gamma_gross) .* within_period_choices.expected_revenue .- param.proba_c .* within_period_choices.expected_m_R_C) .* distrib) * mass_firms
    gross_tax_revenues = cit_revenue + vat_revenue
    total_subsidies = param.proba_c * sum( within_period_choices.expected_subsidies_C .* distrib ) * mass_firms
    net_govt_transfers = gross_tax_revenues - total_subsidies     

    # compute total HH income (not including capital)
    total_net_profits = profit_sum * mass_firms
    total_HH_income = net_govt_transfers + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = net_govt_transfers + w*total_labor_demand

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_labor_demand = total_labor_demand,
        total_rent_seeking = total_rent_seeking, gross_tax_revenues = gross_tax_revenues, total_subsidies = total_subsidies,
        total_net_profits = total_net_profits,  total_HH_consumption_closed = total_HH_consumption_closed, 
        net_govt_transfers = net_govt_transfers, total_value_added_output = total_value_added_output, 
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits)
end 

# compute GE aggregates (NC-only economy)
# Aggregate per-grid-point expected firm choices into economy-wide totals when
# only non-connected (NC) firms are present. There is no rent-seeking and no
# subsidy term. Every total has the form `sum(x .* distrib) * mass_firms`.
#
# Keywords:
#   w                      wage (multiplies labor demand in household income)
#   distrib                per-grid-point weights
#   mass_firms             total mass of firms
#   within_period_choices  object with per-grid-point vectors expected_capital,
#                          expected_labor, expected_revenue, expected_revenue_output,
#                          expected_profits
#   tax                    rate(s) applied in the VAT revenue term; `nothing` falls
#                          back to a copy of the global `vat_tax`
#   param                  accepted for interface symmetry; not read here
#
# Also reads the globals `gamma_gross`, `delta` and `profit_tax`.
# Returns a NamedTuple of aggregates.
# `total_value_added_output` equals `value_added_output`; the alias matches the
# field name returned by `compute_aggregates_wedges`.
function compute_aggregates_wedges_NC(; w, distrib, mass_firms, within_period_choices, tax = nothing, param) 

    # default to the status-quo tax
    if isnothing(tax)
        tax = copy(vat_tax)
    end 

    # Get total (productive) inputs; intermediates are a gamma_gross share of revenue
    productive_capital = sum( within_period_choices.expected_capital .* distrib ) * mass_firms
    productive_labor = sum( within_period_choices.expected_labor .* distrib ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms

    # Get value-added output 
    value_added_output = total_output - productive_intermediates

    # labor market: demand equals productive labor 
    total_labor_demand = productive_labor

    # HH consumption under closed economy 
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates # (enforce aggregate resource constraint -- only true if closed economy or BOP = 0) 

    # weighted profit sum, reused for CIT revenue and total profits
    profit_sum = sum( within_period_choices.expected_profits .* distrib )

    # total tax revenues (no subsidies in the NC economy); profits are presumably
    # after-tax, so profit_tax/(1-profit_tax) grosses them up to the tax paid
    cit_revenue = profit_sum * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* ((1-gamma_gross) .* within_period_choices.expected_revenue) .* distrib) * mass_firms
    gross_tax_revenues = cit_revenue + vat_revenue
    net_govt_transfers = gross_tax_revenues 

    # compute total HH income (not including capital)
    total_net_profits = profit_sum * mass_firms
    total_HH_income = net_govt_transfers + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = net_govt_transfers + w*total_labor_demand

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_labor_demand = total_labor_demand,
        gross_tax_revenues = gross_tax_revenues,
        total_net_profits = total_net_profits,  total_HH_consumption_closed = total_HH_consumption_closed, 
        net_govt_transfers = net_govt_transfers, value_added_output = value_added_output, 
        total_value_added_output = value_added_output,
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits)
end 

## compute equilibrium in cf with wedges 
# Solve for the wage `w` and aggregate output `Y` of the counterfactual with
# wedges. Only the NC firm problem (`compute_subsidy_wedges_baseline_NC`) is
# solved; there is no C branch.
#
# Damped fixed-point iteration on (w, Y):
#   1. map the productivity grid to z_star = z^((sigma-1)/sigma) * Y^(1/sigma)
#      and shift the mean of log(z_star); `param.sd_log_z_star` is used as is.
#   2. For every grid point:
#        - draw `n_draws` log-normal NC wedges whose mean depends on log(z_star)
#          through `corr_wedge_z_NC`;
#        - resample relative capital wedges from `capital_wedge_share_NC`;
#        - solve the firm problem;
#        - average over draws with normalized density weights.
#   3. Aggregate with `distrib` and `mass`. Then move w toward labor-market
#      clearing and Y toward realized output, scaling the relative gaps by
#      `update_param_w` / `update_param_Y`.
# Stops when |gap_L| + |gap_Y| <= `crit`, or after `max_iter` iterations;
# a warning is logged if it did not converge.
#
# `tax = nothing` falls back to a copy of the global `vat_tax`.
# Progress and gaps are printed only when `verbose = true`.
#
# Returns `(w, Y, aggregates)`, where `aggregates` comes from
# `compute_aggregates_wedges_NC` evaluated at the final guesses.
#
# Side effect: calls `Random.seed!`, resetting the global RNG state.
function find_equilibrium_cf_wedges(; guess_w, guess_Y, distrib, mass, aggr_labor_supply, tax = nothing, param, z_grid, mean_log_z, n_draws = 500, update_param_w = 0.5, update_param_Y = 0.5, crit = 1e-6, max_iter = 500, verbose = false)

    # default to the status-quo tax
    if isnothing(tax)
        tax = copy(vat_tax)
    end 

    # iteration state
    diff = Inf 
    iter = 1
    guess_w_new = copy(guess_w)
    guess_Y_new = copy(guess_Y)

    # per-grid-point expectations over wedge draws (overwritten each iteration)
    n_rows = length(z_grid)
    expected_profits = zeros(n_rows)
    expected_revenue_output = zeros(n_rows)
    expected_revenue = zeros(n_rows)
    expected_labor = zeros(n_rows)
    expected_capital = zeros(n_rows)

    # `<=` so that up to `max_iter` iterations actually run
    while diff > crit && iter <= max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, " and w: ", guess_w_new)
        end 

        # Step 1: productivity grid implied by the current Y guess
        x_bar = guess_Y_new^(1/param.sigma)
        z_star_grid = (z_grid.^((param.sigma - 1)/param.sigma)) .* x_bar

        # only the mean of log(z_star) moves with Y
        mean_log_z_star = ((param.sigma - 1)/param.sigma)*mean_log_z + (1/param.sigma)*log(guess_Y_new)

        # Step 2: draw wedges and solve the firm problem at each grid point
        @inbounds for row = 1:n_rows

            # Fixed seed, the same for every row and iteration.
            # Presumably this gives common random numbers so the fixed point is
            # not perturbed by simulation noise.
            Random.seed!(12345)
            normal_distrib_wedge_NC = Normal(
                param.mean_tau_NC + param.corr_wedge_z_NC*(param.sd_tau_NC/param.sd_log_z_star)*(log(z_star_grid[row]) - mean_log_z_star), 
                sqrt((1-(param.corr_wedge_z_NC^2)))*param.sd_tau_NC)
            wedge_NC = exp.(rand(normal_distrib_wedge_NC, n_draws))
            proba_wedge_NC = pdf.(normal_distrib_wedge_NC, log.(wedge_NC))
            proba_wedge_NC = proba_wedge_NC ./ sum(proba_wedge_NC)

            # relative capital wedges, resampled with a row-specific seed
            Random.seed!(12346 + row)
            relative_capital_wedges_NC = sample(capital_wedge_share_NC, n_draws)

            # optimal choices conditional on (w, Y) for every wedge draw
            results_NC = compute_subsidy_wedges_baseline_NC(
                z_star_grid = fill(z_star_grid[row], n_draws), 
                wedges_grid = wedge_NC, 
                relative_wedge = relative_capital_wedges_NC,
                w = guess_w_new, tax = tax, 
                param = param)

            # expectations: weighted mean across wedge draws
            expected_profits[row] = sum(results_NC.optimal_profits_NC .* proba_wedge_NC)
            expected_revenue_output[row] = sum(results_NC.revenue_output_NC .* proba_wedge_NC)
            expected_revenue[row] = sum(results_NC.optimal_revenue .* proba_wedge_NC)
            expected_labor[row] = sum(results_NC.optimal_labor .* proba_wedge_NC)
            expected_capital[row] = sum(results_NC.optimal_capital .* proba_wedge_NC)  
        end 

        # Step 3: aggregate (enforcing equilibrium distribution + mass)
        total_labor_demand = sum(expected_labor .* distrib) .* mass 
        total_output = sum(expected_revenue_output .* distrib) .* mass 

        # labor market clearing gap -> wage update
        diff_L = (total_labor_demand - aggr_labor_supply)/aggr_labor_supply 
        if verbose
            println("Difference Labor demand is: ", diff_L) 
        end
        guess_w_new = guess_w_new * (1.0 + diff_L*update_param_w)

        # output consistency gap -> output update
        diff_Y = (total_output - guess_Y_new)/guess_Y_new 
        if verbose
            println("Difference output is: ", diff_Y)
        end
        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        diff = abs(diff_Y) + abs(diff_L) 
        iter = iter + 1 
    end 

    if diff > crit
        @warn "find_equilibrium_cf_wedges did not converge" iterations = iter - 1 diff
    end

    ### compute aggregates at the final guesses
    profit_results = (
        expected_profits = expected_profits, 
        expected_revenue_output = expected_revenue_output,
        expected_revenue = expected_revenue, 
        expected_labor = expected_labor,
        expected_capital = expected_capital)

    aggregates = compute_aggregates_wedges_NC(
        w = guess_w_new, 
        distrib = distrib, 
        mass_firms = mass, 
        within_period_choices = profit_results, 
        tax = tax, 
        param = param) 

    return (w = guess_w_new, Y = guess_Y_new, aggregates = aggregates)
end 

## compute equilibrium in baseline with wedges 
# Solve for the wage `w` and aggregate output `Y` of the baseline economy with
# wedges. At every grid point it mixes connected (C, weight `param.proba_c`)
# and non-connected (NC, weight `1 - param.proba_c`) firms.
#
# It uses the same damped (w, Y) fixed-point scheme as
# `find_equilibrium_cf_wedges`. At each grid point:
#   - C firms: draw `epsilon` | z_star and log-normal C wedges. Weight them by
#     the normalized product of their densities (`proba_joint_C`) and solve
#     with `compute_subsidy_wedges_baseline`.
#   - NC firms: draw log-normal NC wedges and solve with
#     `compute_subsidy_wedges_baseline_NC`.
#   - Profits, revenues and inputs are `(1 - proba_c) * E_NC + proba_c * E_C`.
#     `expected_m_R_C` and `expected_subsidies_C` are conditional on C; they
#     are not scaled by `proba_c` here.
# Stops when |gap_L| + |gap_Y| <= `crit`, or after `max_iter` iterations;
# a warning is logged if it did not converge.
#
# `tax = nothing` falls back to a copy of the global `vat_tax`.
# Progress and gaps are printed only when `verbose = true`.
#
# Returns `(w, Y, aggregates)`, where `aggregates` comes from
# `compute_aggregates_wedges` evaluated at the final guesses.
#
# Side effect: calls `Random.seed!`, resetting the global RNG state.
function find_equilibrium_baseline_wedges(; guess_w, guess_Y, distrib, mass, aggr_labor_supply, tax = nothing, param, z_grid, mean_log_z, n_draws = 500, update_param_w = 0.5, update_param_Y = 0.5, crit = 1e-6, max_iter = 500, verbose = false)

    # default to the status-quo tax
    if isnothing(tax)
        tax = copy(vat_tax)
    end 

    # iteration state
    diff = Inf 
    iter = 1
    guess_w_new = copy(guess_w)
    guess_Y_new = copy(guess_Y)

    # per-grid-point expectations (overwritten each iteration)
    n_rows = length(z_grid)
    expected_profits = zeros(n_rows)
    expected_revenue_output = zeros(n_rows)
    expected_revenue = zeros(n_rows)
    expected_labor = zeros(n_rows)
    expected_capital = zeros(n_rows)
    expected_m_R_C = zeros(n_rows)          # conditional on C
    expected_subsidies_C = zeros(n_rows)    # conditional on C

    # `<=` so that up to `max_iter` iterations actually run
    while diff > crit && iter <= max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, " and w: ", guess_w_new)
        end 

        # Step 1: productivity grid implied by the current Y guess
        x_bar = guess_Y_new^(1/param.sigma)
        z_star_grid = (z_grid.^((param.sigma - 1)/param.sigma)) .* x_bar

        # only the mean of log(z_star) moves with Y
        mean_log_z_star = ((param.sigma - 1)/param.sigma)*mean_log_z + (1/param.sigma)*log(guess_Y_new)

        # Step 2: draw shocks/wedges and solve the firm problems at each grid point
        @inbounds for row = 1:n_rows

            log_z_star = log(z_star_grid[row])

            ### Substep 1: epsilon | z_star and its normalized probability weights
            Random.seed!(12345)
            normal_distrib_eps = Normal(param.alpha_eps + param.beta_eps*log_z_star, sqrt(param.variance_eps_z))
            epsilon = rand(normal_distrib_eps, n_draws)
            proba_epsilon = pdf.(normal_distrib_eps, epsilon)
            proba_epsilon = proba_epsilon ./ sum(proba_epsilon)

            ### Substep 2: C wedges | z_star
            # NOTE(review): this reuses the seed used for `epsilon`. So
            # `log.(wedge_C)` and `epsilon` likely share the same underlying
            # standard-normal draws and are perfectly correlated, while
            # `proba_joint_C` multiplies their densities as if they were
            # independent. Confirm this is intended.
            Random.seed!(12345)
            normal_distrib_wedge_C = Normal(
                param.mean_tau_C + param.corr_wedge_z_C*(param.sd_tau_C/param.sd_log_z_star)*(log_z_star - mean_log_z_star), 
                sqrt((1-(param.corr_wedge_z_C^2)))*param.sd_tau_C) 
            wedge_C = exp.(rand(normal_distrib_wedge_C, n_draws))
            proba_wedge_C = pdf.(normal_distrib_wedge_C, log.(wedge_C))
            proba_wedge_C = proba_wedge_C ./ sum(proba_wedge_C)

            # relative capital wedges for C, resampled with a row-specific seed
            Random.seed!(12346 + row)
            relative_capital_wedges_C = sample(capital_wedge_share_C, n_draws) 

            # joint (epsilon, wedge) weights for C firms, renormalized to sum to 1
            proba_joint_C = (proba_wedge_C .* proba_epsilon) ./ sum(proba_wedge_C .* proba_epsilon)

            ### Substep 3: NC wedges | z_star
            Random.seed!(12345)
            normal_distrib_wedge_NC = Normal(
                param.mean_tau_NC + param.corr_wedge_z_NC*(param.sd_tau_NC/param.sd_log_z_star)*(log_z_star - mean_log_z_star), 
                sqrt((1-(param.corr_wedge_z_NC^2)))*param.sd_tau_NC)
            wedge_NC = exp.(rand(normal_distrib_wedge_NC, n_draws))
            proba_wedge_NC = pdf.(normal_distrib_wedge_NC, log.(wedge_NC))
            proba_wedge_NC = proba_wedge_NC ./ sum(proba_wedge_NC)

            # relative capital wedges for NC (same row-specific seed as C)
            Random.seed!(12346 + row)
            relative_capital_wedges_NC = sample(capital_wedge_share_NC, n_draws)

            ### Substep 4: optimal choices conditional on (w, Y)
            results_C = compute_subsidy_wedges_baseline(
                z_star_grid = fill(z_star_grid[row], n_draws), 
                epsilon_grid = epsilon, 
                wedges_grid = wedge_C, 
                relative_wedge = relative_capital_wedges_C,
                wage = guess_w_new,  
                tax = tax, 
                param = param) 

            # NC version: no epsilon, but with wedges
            results_NC = compute_subsidy_wedges_baseline_NC(
                z_star_grid = fill(z_star_grid[row], n_draws), 
                wedges_grid = wedge_NC, 
                relative_wedge = relative_capital_wedges_NC,
                w = guess_w_new, 
                tax = tax, 
                param = param)

            # expectations: proba_c-weighted mix of the NC and C weighted means
            expected_profits[row] = (1 - param.proba_c) * sum(results_NC.optimal_profits_NC .* proba_wedge_NC) + param.proba_c * sum(results_C.optimal_profits_C .* proba_joint_C)
            expected_revenue_output[row] = (1 - param.proba_c) * sum(results_NC.revenue_output_NC .* proba_wedge_NC) + param.proba_c * sum(results_C.revenue_output_C .* proba_joint_C)
            expected_revenue[row] = (1 - param.proba_c) * sum(results_NC.optimal_revenue .* proba_wedge_NC) + param.proba_c * sum(results_C.optimal_revenue .* proba_joint_C)
            expected_labor[row] = (1 - param.proba_c) * sum(results_NC.optimal_labor .* proba_wedge_NC) + param.proba_c * sum(results_C.optimal_labor .* proba_joint_C)
            expected_capital[row] = (1 - param.proba_c) * sum(results_NC.optimal_capital .* proba_wedge_NC) + param.proba_c * sum(results_C.optimal_capital .* proba_joint_C)
            expected_m_R_C[row] = sum(results_C.optimal_m_R .* proba_joint_C)
            expected_subsidies_C[row] = sum(results_C.optimal_subsidy .* proba_joint_C)
        end 

        # Step 3: aggregate (enforcing equilibrium distribution + mass)
        total_labor_demand = sum(expected_labor .* distrib) .* mass 
        total_output = sum(expected_revenue_output .* distrib) .* mass 

        # labor market clearing gap -> wage update
        diff_L = (total_labor_demand - aggr_labor_supply)/aggr_labor_supply 
        if verbose
            println("Difference Labor demand is: ", diff_L) 
        end
        guess_w_new = guess_w_new * (1.0 + diff_L*update_param_w)

        # output consistency gap -> output update
        diff_Y = (total_output - guess_Y_new)/guess_Y_new 
        if verbose
            println("Difference output is: ", diff_Y)
        end
        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        diff = abs(diff_Y) + abs(diff_L) 
        iter = iter + 1 
    end 

    if diff > crit
        @warn "find_equilibrium_baseline_wedges did not converge" iterations = iter - 1 diff
    end

    ### compute aggregates at the final guesses
    profit_results = (
        expected_profits = expected_profits, 
        expected_revenue_output = expected_revenue_output,
        expected_revenue = expected_revenue, 
        expected_labor = expected_labor,
        expected_capital = expected_capital, 
        expected_m_R_C = expected_m_R_C, 
        expected_subsidies_C = expected_subsidies_C
        )

    aggregates = compute_aggregates_wedges(
        w = guess_w_new, 
        distrib = distrib, 
        mass_firms = mass, 
        within_period_choices = profit_results, 
        tax = tax, 
        param = param) 

    return (w = guess_w_new, Y = guess_Y_new, aggregates = aggregates)
end

## compute equilibrium in cf with wedges and tax rate 
# Solve jointly for the wage `w`, aggregate output `Y` and the tax rate in the
# NC-only counterfactual with wedges.
# The tax must make net government transfers equal `aggr_govt_spending`.
#
# Each iteration:
#   1. rebuilds the z_star grid from Y;
#   2. draws NC wedges and solves `compute_subsidy_wedges_baseline_NC` at every
#      grid point, using the current tax guess;
#   3. aggregates with `compute_aggregates_wedges_NC`;
#   4. updates all three guesses from their relative gaps:
#        w   *= 1 + gap_L * update_param_w
#        Y   *= 1 + gap_Y * update_param_Y
#        tax *= 1 - gap_tax * update_param_tax   (lower the tax when revenue exceeds spending)
# The tax update is multiplicative, so a zero tax guess never moves.
# A zero `aggr_govt_spending` makes the revenue gap non-finite.
#
# Stops when |gap_L| + |gap_Y| + |gap_tax| <= `crit`, or after `max_iter`
# iterations; a warning is logged if it did not converge.
# Progress and gaps are printed only when `verbose = true`.
#
# Returns `(w, Y, tax, aggregates)`, where `aggregates` is evaluated at the
# final guesses.
#
# Side effect: calls `Random.seed!`, resetting the global RNG state.
function find_equilibrium_cf_wedges_tax(; guess_w, guess_Y, guess_tax, distrib, mass, aggr_labor_supply, aggr_govt_spending, param, z_grid, mean_log_z, n_draws = 500, update_param_w = 0.5, update_param_Y = 0.5, update_param_tax = 0.2, crit = 1e-6, max_iter = 500, verbose = false)

    # iteration state
    diff = Inf 
    iter = 1
    guess_w_new = copy(guess_w)
    guess_Y_new = copy(guess_Y)
    guess_tax_new = copy(guess_tax)

    # per-grid-point expectations over wedge draws (overwritten each iteration)
    n_rows = length(z_grid)
    expected_profits = zeros(n_rows)
    expected_revenue_output = zeros(n_rows)
    expected_revenue = zeros(n_rows)
    expected_labor = zeros(n_rows)
    expected_capital = zeros(n_rows)

    # `<=` so that up to `max_iter` iterations actually run
    while diff > crit && iter <= max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, " and w: ", guess_w_new, " and tax: ", guess_tax_new)
        end 

        # Step 1: productivity grid implied by the current Y guess
        x_bar = guess_Y_new^(1/param.sigma)
        z_star_grid = (z_grid.^((param.sigma - 1)/param.sigma)) .* x_bar

        # only the mean of log(z_star) moves with Y
        mean_log_z_star = ((param.sigma - 1)/param.sigma)*mean_log_z + (1/param.sigma)*log(guess_Y_new)

        # Step 2: draw wedges and solve the firm problem at each grid point
        @inbounds for row = 1:n_rows

            # Fixed seed, the same for every row and iteration (common random numbers, presumably)
            Random.seed!(12345)
            normal_distrib_wedge_NC = Normal(
                param.mean_tau_NC + param.corr_wedge_z_NC*(param.sd_tau_NC/param.sd_log_z_star)*(log(z_star_grid[row]) - mean_log_z_star), 
                sqrt((1-(param.corr_wedge_z_NC^2)))*param.sd_tau_NC)
            wedge_NC = exp.(rand(normal_distrib_wedge_NC, n_draws))
            proba_wedge_NC = pdf.(normal_distrib_wedge_NC, log.(wedge_NC))
            proba_wedge_NC = proba_wedge_NC ./ sum(proba_wedge_NC)

            # relative capital wedges, resampled with a row-specific seed
            Random.seed!(12346 + row)
            relative_capital_wedges_NC = sample(capital_wedge_share_NC, n_draws)

            # optimal choices conditional on (w, Y, tax) for every wedge draw
            results_NC = compute_subsidy_wedges_baseline_NC(
                z_star_grid = fill(z_star_grid[row], n_draws), 
                wedges_grid = wedge_NC, 
                relative_wedge = relative_capital_wedges_NC,
                w = guess_w_new, tax = guess_tax_new, 
                param = param)

            # expectations: weighted mean across wedge draws
            expected_profits[row] = sum(results_NC.optimal_profits_NC .* proba_wedge_NC)
            expected_revenue_output[row] = sum(results_NC.revenue_output_NC .* proba_wedge_NC)
            expected_revenue[row] = sum(results_NC.optimal_revenue .* proba_wedge_NC)
            expected_labor[row] = sum(results_NC.optimal_labor .* proba_wedge_NC)
            expected_capital[row] = sum(results_NC.optimal_capital .* proba_wedge_NC)  
        end 

        # Step 3: aggregate (also needed for tax revenue)
        profit_results = (
            expected_profits = expected_profits, 
            expected_revenue_output = expected_revenue_output,
            expected_revenue = expected_revenue, 
            expected_labor = expected_labor,
            expected_capital = expected_capital)

        aggregates = compute_aggregates_wedges_NC(
            w = guess_w_new, 
            distrib = distrib, 
            mass_firms = mass, 
            within_period_choices = profit_results, 
            tax = guess_tax_new, 
            param = param) 

        total_labor_demand = aggregates.total_labor_demand
        total_output = aggregates.total_output 
        total_tax_revenue = aggregates.net_govt_transfers

        # labor market clearing gap -> wage update
        diff_L = (total_labor_demand - aggr_labor_supply)/aggr_labor_supply 
        if verbose
            println("Difference Labor demand is: ", diff_L) 
        end
        guess_w_new = guess_w_new * (1.0 + diff_L*update_param_w)

        # output consistency gap -> output update
        diff_Y = (total_output - guess_Y_new)/guess_Y_new 
        if verbose
            println("Difference output is: ", diff_Y)
        end
        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        # government budget gap -> tax update (revenue above spending lowers the tax)
        diff_tax = (total_tax_revenue - aggr_govt_spending)/aggr_govt_spending
        if verbose
            println("Difference tax revenue is: ", diff_tax) 
        end
        guess_tax_new = guess_tax_new * (1.0 - diff_tax*update_param_tax)

        diff = abs(diff_Y) + abs(diff_L) + abs(diff_tax)
        iter = iter + 1 
    end 

    if diff > crit
        @warn "find_equilibrium_cf_wedges_tax did not converge" iterations = iter - 1 diff
    end

    ### compute aggregates at the final guesses
    profit_results = (
        expected_profits = expected_profits, 
        expected_revenue_output = expected_revenue_output,
        expected_revenue = expected_revenue, 
        expected_labor = expected_labor,
        expected_capital = expected_capital)

    aggregates = compute_aggregates_wedges_NC(
        w = guess_w_new, 
        distrib = distrib, 
        mass_firms = mass, 
        within_period_choices = profit_results, 
        tax = guess_tax_new, 
        param = param) 

    return (w = guess_w_new, Y = guess_Y_new, tax = guess_tax_new, aggregates = aggregates)
end 
